package selector

import (
	"context"
	"errors"
	"math/rand"
	"sync"
	"sync/atomic"
	"time"
)

// WeightRandomBanlance is a weighted-random selector: Next picks one of the
// nodes stored by Add with probability proportional to its InitialWeight.
// All fields are accessed with mux held by Add and Next.
type WeightRandomBanlance struct {
	// nodeCount is the number of nodes passed to the most recent Add call.
	// NOTE(review): not read by the code in this file; confirm other users.
	nodeCount int
	// totalWeight is the sum of InitialWeight over the stored nodes,
	// recomputed by Add; Next refuses to select when it is <= 0.
	totalWeight int64
	// nodes holds the []Node stored by Add; Load returns nil until the
	// first Add call.
	nodes atomic.Value
	// mux serialises Add and Next.
	mux sync.Mutex
}

// WeightRandomBuilder wraps a weighted-random balancer.
// NOTE(review): nothing in this file constructs or uses it —
// NewWeightRandomBuilder returns a *WeightRandomBanlance instead, which
// implements Build itself. Confirm whether this type is still needed.
type WeightRandomBuilder struct {
	wr *WeightRandomBanlance
}

// NewWeightRandomBuilder returns a Builder that produces weighted-random
// selectors. An empty *WeightRandomBanlance is used as the builder because
// its Build method ignores the receiver and returns a brand-new selector.
func NewWeightRandomBuilder() Builder {
	var builder WeightRandomBanlance
	return &builder
}

// Build returns a new, empty weighted-random Selector. The receiver's state
// is never consulted, so every call yields an independent balancer.
func (*WeightRandomBanlance) Build() Selector {
	return NewWeightRandomBanlance()
}

// NewWeightRandomBanlance returns an empty weighted-random Selector with no
// nodes and zero total weight. Add must be called before Next can return a
// node.
func NewWeightRandomBanlance() Selector {
	return new(WeightRandomBanlance)
}

// Add replaces the candidate set with nodes and recomputes the cached total
// weight. It does not append to nodes from earlier calls. It is safe to call
// concurrently with Next.
//
// The slice is copied before being stored, because totalWeight is derived
// from it. If the caller's backing array were kept and later modified, the
// stored nodes and totalWeight would disagree.
func (wr *WeightRandomBanlance) Add(nodes ...Node) error {
	wr.mux.Lock()
	defer wr.mux.Unlock()
	snapshot := append([]Node(nil), nodes...)
	wr.nodes.Store(snapshot)
	wr.nodeCount = len(snapshot)
	// init accumulates onto totalWeight, so reset it first.
	wr.totalWeight = 0
	return wr.init()
}

// init 初始化节点信息
// init adds the InitialWeight of every stored node onto wr.totalWeight.
// The caller must hold wr.mux and, for a fresh total, must reset
// totalWeight to 0 beforehand (Add does both).
// It returns an error "NODE_NOT_FOUND" if no node slice has been stored yet.
func (wr *WeightRandomBanlance) init() error {
	// Load straight from the field: an atomic.Value must not be copied
	// after first use.
	v := wr.nodes.Load()
	if v == nil {
		return errors.New("NODE_NOT_FOUND")
	}
	for _, node := range v.([]Node) {
		wr.totalWeight += node.InitialWeight()
	}
	return nil
}

// weightRandMu guards weightRandSrc, a generator seeded once at package
// initialisation. It replaces reseeding the global math/rand source on every
// selection. Reseeding per call is costly and clobbers process-wide state.
// It also makes selections within the same clock tick draw identical values.
// The top-level rand.Seed has been deprecated since Go 1.20 and is a no-op by
// default since Go 1.24. *rand.Rand is not safe for concurrent use, and this
// generator is shared by every balancer, so it has its own lock independent
// of each balancer's mux.
var (
	weightRandMu  sync.Mutex
	weightRandSrc = rand.New(rand.NewSource(time.Now().UnixNano()))
)

// weightRandInt63n returns a pseudo-random int64 in [0, n) from the shared
// generator. n must be > 0.
func weightRandInt63n(n int64) int64 {
	weightRandMu.Lock()
	defer weightRandMu.Unlock()
	return weightRandSrc.Int63n(n)
}

// Next picks one of the nodes stored by Add, with probability proportional
// to its InitialWeight. If ctx carries a peer (per FromPeerContext), the
// chosen node is recorded on it.
//
// Errors:
//   - "NODE_NOT_FOUND" if Add has never been called;
//   - "wr.totalWeight <= 0" if the summed weight is not positive.
func (wr *WeightRandomBanlance) Next(ctx context.Context) (selected Node, err error) {
	wr.mux.Lock()
	defer wr.mux.Unlock()
	v := wr.nodes.Load()
	if v == nil {
		return nil, errors.New("NODE_NOT_FOUND")
	}
	if wr.totalWeight <= 0 {
		return nil, errors.New("wr.totalWeight <= 0")
	}
	nodes := v.([]Node)
	// Draw a point in [0, totalWeight) and walk the cumulative weights until
	// the point falls inside a node's share.
	weight := weightRandInt63n(wr.totalWeight)
	for _, node := range nodes {
		weight -= node.InitialWeight()
		if weight < 0 {
			if peer, ok := FromPeerContext(ctx); ok {
				peer.Node = node
			}
			return node, nil
		}
	}

	// The walk always terminates above while each node's InitialWeight still
	// matches what Add summed into totalWeight; this is a safeguard for the
	// case where it does not.
	return nil, errors.New("unkown err")
}
